/// Picks the initially infected node whose complete removal from the graph
/// leaves the fewest nodes infected once the malware has finished spreading.
///
/// Approach: union the nodes that are NOT initially infected into connected
/// components (ignoring every infected node), then look at which infected
/// nodes touch each clean component. A clean component is rescued by removing
/// an infected node only when that node is the sole infected node adjacent
/// to it.
class Solution {
public:
    /// Union-find parent table. It is rebuilt from scratch at the start of
    /// every minMalwareSpread() call, so one instance can be reused safely.
    std::vector<int> p;

    /// @param graph   Square adjacency matrix; graph[i][j] != 0 means nodes i
    ///                and j are directly connected.
    /// @param initial Indices of the initially infected nodes. Repeated
    ///                entries are tolerated and treated as a single node.
    /// @return The node from `initial` whose removal saves the most clean
    ///         nodes; ties go to the smallest index. Returns 0 when `initial`
    ///         is empty.
    int minMalwareSpread(std::vector<std::vector<int>>& graph, std::vector<int>& initial) {
        const int n = static_cast<int>(graph.size());

        // Reset the parent table: leftovers from a previous call must never
        // leak into this one.
        p.assign(n, 0);
        for (int i = 0; i < n; ++i) p[i] = i;

        std::vector<char> infected(n, 0);
        for (int node : initial) infected[node] = 1;

        // Merge clean nodes into components; infected nodes act as walls.
        std::vector<int> compSize(n, 1);
        for (int i = 0; i < n; ++i) {
            if (infected[i]) continue;
            for (int j = i + 1; j < n; ++j) {
                if (infected[j] || !graph[i][j]) continue;
                const int rootI = find(i);
                const int rootJ = find(j);
                if (rootI == rootJ) continue;
                p[rootI] = rootJ;
                compSize[rootJ] += compSize[rootI];
            }
        }

        // sourceCount[r]: how many infected nodes border the component whose
        // root is r. lastSource[r]: the infected node that most recently
        // bordered it; it also serves as a stamp so a node is counted at most
        // once per component, even when it has several neighbours inside the
        // component or appears more than once in `initial`.
        std::vector<int> sourceCount(n, 0);
        std::vector<int> lastSource(n, -1);
        for (int node : initial) {
            for (int j = 0; j < n; ++j) {
                if (infected[j] || !graph[node][j]) continue;
                const int root = find(j);
                if (lastSource[root] == node) continue;
                lastSource[root] = node;
                ++sourceCount[root];
            }
        }

        // A component bordered by exactly one infected node is saved in full
        // when that node is removed. Only roots can have a non-zero count.
        std::vector<int> saved(n, 0);
        for (int root = 0; root < n; ++root) {
            if (sourceCount[root] == 1) saved[lastSource[root]] += compSize[root];
        }

        int bestSaved = -1;
        int answer = 0;
        for (int node : initial) {
            if (saved[node] > bestSaved || (saved[node] == bestSaved && node < answer)) {
                bestSaved = saved[node];
                answer = node;
            }
        }
        return answer;
    }

    /// Returns the representative (root) of the set containing x, shortening
    /// the parent chain along the way. Iterative, so long chains cannot
    /// deepen the call stack.
    int find(int x) {
        while (p[x] != x) {
            p[x] = p[p[x]];  // path halving: skip over the direct parent
            x = p[x];
        }
        return x;
    }
};